module Functions_wa_TR
    
    use Globals
    use Toolbox
    use ModuleSaving

contains

    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    !!!     Compute Cont Values  !!!
    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

    subroutine ComputeVe_TR(tt_in)
    ! Expected continuation value for period tt_in, integrating over the
    ! exogenous index jx:
    !   Ve_TR(is,tt_in) = sum_{jx=1..Nx} Px(xind(is),jx) * V_TR(aind(is)+Na*(jx-1), tt_in)
    ! Each state keeps its asset grid index aind(is) and has its exogenous index
    ! replaced by jx, weighted by row xind(is) of Px (presumably transition
    ! probabilities -- Px, aind, xind, V_TR, Ve_TR are defined in Globals).
    !
    ! tt_in : period whose column Ve_TR(:,tt_in) is rebuilt from V_TR(:,tt_in)

    implicit none

    integer, intent(in) :: tt_in
    integer             :: jx
    integer             :: next_idx(NS)    ! state indices sharing aind, with exogenous index jx

    ! Running sum in the order jx = 1, 2, ..., Nx
    Ve_TR(:,tt_in) = 0.0D0
    jx = 0
    do
        jx = jx + 1
        if (jx > Nx) exit
        next_idx       = aind + Na*(jx-1)
        Ve_TR(:,tt_in) = Ve_TR(:,tt_in) + Px(xind,jx)*V_TR(next_idx,tt_in)
    end do

    end subroutine ComputeVe_TR

    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    !!!     Tax Function Vector              !!!
    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

    function TAX_TR(yl_in,yk_in,tt_in) result(T_out)
    ! Net tax payment, element-wise over vectors of labor and capital income:
    !   T = yl * lambda_t**((yl/ymean)**(-gamma))     (labor income tax)
    !     + tauk * yk                                 (capital income tax)
    !     - mt_t*ymean * [e/(1+e)] / chi              (transfers)
    ! with e   = exp(-rt*((yl+yk)/ymean - omegat))
    !      chi = exp(rt*omegat)/(1+exp(rt*omegat)),
    ! so chi normalizes transfers to exactly mt_t*ymean at zero total income.
    ! The scalar counterpart is TAX_sc_TR.
    !
    ! yl_in : labor income; yk_in : capital income (assumed same length as
    !         yl_in -- the result is sized from yl_in)
    ! tt_in : time period (selects lambda_TR(tt_in) and mt_TR(tt_in))

    implicit none

    real(8), intent(in)               :: yl_in(:), yk_in(:)
    integer, intent(in)               :: tt_in
    real(8), dimension(size(yl_in,1)) :: T_out
    real(8), dimension(size(yl_in,1)) :: y_total, labor_tax, transf_amt, e_fac
    real(8)                           :: log_lambda, chi_norm

    log_lambda = log(lambda_TR(tt_in))
    chi_norm   = exp(rt*omegat)/(1.0D0 + exp(rt*omegat))

    ! Labor income tax, lambda**u * yl written via exp/log
    labor_tax = exp(log_lambda*((yl_in/ymean)**(-gamma)))*yl_in

    ! Logistic transfer in total income
    y_total    = yl_in + yk_in
    e_fac      = exp(-rt*(y_total/ymean-omegat))
    transf_amt = (mt_TR(tt_in)*ymean)*(e_fac/(1.0D0+e_fac))*(1.0D0/chi_norm)

    T_out = labor_tax + tauk*yk_in - transf_amt

    end function TAX_TR

    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    !!!              Tax Function SCALAR     !!!
    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

    function TAX_sc_TR(yl_in,yk_in,tt_in) result(T_out)
    ! Net tax payment for a single household (scalar counterpart of TAX_TR):
    !   T = yl*lambda_t**((yl/ymean)**(-gamma)) + tauk*yk - mt_t*ymean*[e/(1+e)]/chi
    ! with e = exp(-rt*((yl+yk)/ymean - omegat)) and
    ! chi = exp(rt*omegat)/(1+exp(rt*omegat)); chi makes the transfer equal
    ! mt_t*ymean when total income is zero.
    !
    ! yl_in : labor income; yk_in : capital income; tt_in : time period

    implicit none

    real(8), intent(in) :: yl_in, yk_in
    integer, intent(in) :: tt_in
    real(8) :: T_out
    real(8) :: log_lambda, labor_tax, y_total, e_fac, chi_norm, transf_amt

    chi_norm   = exp(rt*omegat)/(1.0D0 + exp(rt*omegat))
    log_lambda = log(lambda_TR(tt_in))

    labor_tax  = exp(log_lambda*((yl_in/ymean)**(-gamma)))*yl_in

    y_total    = yl_in + yk_in
    e_fac      = exp(-rt*(y_total/ymean-omegat))
    transf_amt = (mt_TR(tt_in)*ymean)*(e_fac/(1.0D0+e_fac))*(1D0/chi_norm)

    T_out      = labor_tax + tauk*yk_in - transf_amt

    end function TAX_sc_TR


    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    !!!              Tax Function SCALAR     !!!
    !!!                 der wrt yl           !!!
    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

    function dTAX_sc_TR(yl_in,yk_in,tt_in) result(T_out)
    ! Derivative dT/d(yl) of the net tax function TAX_sc_TR, holding yk fixed.
    ! With L = log(lambda_t) and u = (yl/ymean)**(-gamma):
    !   d(inctax)/d(yl)    = exp(L*u) * (1 - gamma*L*u)
    !   d(transfers)/d(yl) = -(rt/ymean) * mt_t*ymean * e/(1+e)**2 / chi
    ! The capital tax tauk*yk does not depend on yl.
    !
    ! yl_in : labor income; yk_in : capital income; tt_in : time period

    implicit none

    real(8), intent(in) :: yl_in, yk_in
    integer, intent(in) :: tt_in
    real(8) :: T_out
    real(8) :: log_lambda, u_pow, d_inctax, d_transf, y_total, e_fac, chi_norm

    ! Labor income tax part
    log_lambda = log(lambda_TR(tt_in))
    u_pow      = (yl_in/ymean)**(-gamma)
    d_inctax   = exp(log_lambda*u_pow)*(1d0-gamma*log_lambda*u_pow)

    ! Transfer part
    y_total  = yl_in + yk_in
    e_fac    = exp(-rt*(y_total/ymean-omegat))
    chi_norm = exp(rt*omegat)/(1.0D0+exp(rt*omegat))
    d_transf = -(rt/ymean)*(mt_TR(tt_in)*ymean)*(e_fac/((1.0D0+e_fac)**2d0))*(1D0/chi_norm)

    T_out = d_inctax - d_transf

    end function dTAX_sc_TR


    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    !!!              Tax Function SCALAR     !!!
    !!!                der2 wrt yl           !!!
    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

    function ddTAX_sc_TR(yl_in,yk_in,tt_in) result(T_out)
    ! Second derivative d2T/d(yl)2 of the net tax function TAX_sc_TR, yk fixed.
    ! With L = log(lambda_t), u = (yl/ymean)**(-gamma), e = exp(-rt*((yl+yk)/ymean - omegat)):
    !   d2(inctax)/d(yl)2    = exp(L*u) * gamma*L*(yl/ymean)**(-gamma-1)/ymean * (gamma - 1 + gamma*L*u)
    !   d2(transfers)/d(yl)2 = (rt/ymean)**2 * mt_t*ymean * e*(1-e)/(1+e)**3 / chi
    ! Squares and cubes use INTEGER exponents: the Fortran standard prohibits
    ! raising a negative real to a real power, so (-rt/ymean)**2D0 is not
    ! permitted when rt > 0 and may evaluate to NaN.
    !
    ! yl_in : labor income; yk_in : capital income; tt_in : time period

    implicit none

    real(8), intent(in) :: yl_in, yk_in
    integer, intent(in) :: tt_in
    real(8) :: T_out
    real(8) :: inctaxll, transfersll, yaux, et, chit, loglam, ynorm

    ! total income (sum of labor and capital income)
    yaux   = yl_in + yk_in
    loglam = log(lambda_TR(tt_in))
    ynorm  = yl_in/ymean

    ! Tax
    inctaxll = exp(loglam*(ynorm**(-gamma)))*gamma*loglam*(ynorm**(-gamma-1.0D0))*(1d0/ymean)
    inctaxll = inctaxll*(gamma-1D0+gamma*loglam*(ynorm**(-gamma)))

    ! Transfers
    et   = exp(-rt*(yaux/ymean-omegat))
    chit = exp(rt*omegat)/(1.0D0 + exp(rt*omegat))
    transfersll = ((rt/ymean)**2)*(mt_TR(tt_in)*ymean)*(et/((1.0D0+et)**3))*(1d0-et)*(1D0/chit)

    ! total net tax
    T_out = inctaxll - transfersll

    end function ddTAX_sc_TR


    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    !!!    For a given scalar (is_in)               !!!
    !!!    Given a, compute V                       !!!
    !!!    hopt found with grid + Golden            !!!
    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

    function VF_eval_wa_hopt_sc_TR(a_in,is_in,save_dum,tt_in) result(V_out)
    ! Value of choosing next-period assets a_in in state is_in at period tt_in,
    ! with hours chosen optimally for that asset choice:
    !   V    = u(c(h*), h*) + beta * Ve(a_in),
    !   c(h) = (a + yk + yLaux*h - T(yLaux*h, yk) - a_in)/(1+tauc),
    !   u    = c**(1-sg)/(1-sg) - B*h**(1+varphi)/(1+varphi).
    ! Hours are found in three steps:
    !   1. Evaluate u on the grid h_vec.
    !   2. If u''(h) < 0 at every grid point, run Newton from the best grid point.
    !   3. If Newton is skipped, fails, or lands outside the bracket formed by the
    !      neighbouring grid points, run golden-section search on that bracket.
    !
    ! a_in     : next-period asset choice
    ! is_in    : state index (S(is,1) = assets, S(is,2) = productivity)
    ! save_dum : 1 -> store h and c in h_pol_TR/c_pol_TR; 0 -> evaluate only;
    !            >1 -> prints an error message
    ! tt_in    : period; the continuation value is read from Ve_TR(:,tt_in+1)

    implicit none

    ! Local variable declarations
    real(8), intent(in) :: a_in
    integer, intent(in) :: is_in, save_dum, tt_in

    real(8) :: V_out, uuh(Nh), soc(Nh)
    real(8) :: yk, yl, ylaux, caux, uu, Ve_np, haux, h_in
    real(8) :: dcdh, d2cdh2

    real(8) :: anp, anp_vec(1), aaux
    integer :: aind_np, aind_np_vec(1), aux(1), aux_int
    integer :: Sind_np, ih

    ! Asset choice
    anp = a_in

    ! Continuation value: linear interpolation of Ve_TR(:,tt_in+1) in the asset
    ! direction, within the states that share this state's exogenous index
    anp_vec     = anp
    aind_np_vec = vsearchprm_FET(anp_vec,a_vec,aprm)
    aind_np     = aind_np_vec(1)
    Sind_np     = aind_np + Na*(xind(is_in)-1)

    Ve_np = Ve_TR(Sind_np,tt_in+1) + (anp - a_vec(aind_np)) &
          * ( Ve_TR(Sind_np+1,tt_in+1) - Ve_TR(Sind_np,tt_in+1) ) &
          / (a_vec(aind_np+1) - a_vec(aind_np))
    ! NOTE(review): this caps the continuation value from ABOVE at 1D-8. It is
    ! harmless when values are always negative (e.g. sg > 1) but clips positive
    ! values otherwise -- confirm that min (not max) is intended.
    Ve_np = min(Ve_np,1D-8)

    ! Income and assets
    yK    = r_TR(tt_in)*S(is_in,1)      ! capital income
    yLaux = wge_TR(tt_in)*S(is_in,2)    ! labor income per unit of hours
    aaux  = S(is_in,1)                  ! assets

    ! Static utility and its second derivative in h on the hours grid
    do ih = 1, Nh
        h_in = h_vec(ih)
        yL   = yLaux*h_in
        caux = (aaux + yK + yL - TAX_sc_TR(yl,yk,tt_in) - anp)/(1d0+tauc)

        if (caux>0.0D0) then    ! feasible consumption
            ! u''(h) = -sg*c**(-sg-1)*(dc/dh)**2 + c**(-sg)*d2c/dh2 - varphi*B*h**(varphi-1)
            ! dc/dh carries 1/(1+tauc), so (dc/dh)**2 carries 1/(1+tauc)**2
            dcdh    = yLaux*(1d0-dTAX_sc_TR(yl,yk,tt_in))/(1d0+tauc)
            d2cdh2  = -(yLaux**2)*ddTAX_sc_TR(yl,yk,tt_in)/(1d0+tauc)
            soc(ih) = -sg*(caux**(-sg-1d0))*(dcdh**2) + (caux**(-sg))*d2cdh2 - varphi*B*(h_in**(varphi-1d0))
            uuh(ih) = ((caux**(1.0D0-sg))/(1.0D0-sg)) - B*(((h_in)**(1.0D0+varphi))/(1.0D0+varphi))
        else
            ! infeasible: utility penalty; soc set negative so this point alone does not block Newton
            soc(ih) = -1D0
            uuh(ih) = -1000D0
        endif
    enddo

    ! Pick best labor supply on grid
    aux     = maxloc(uuh)
    aux_int = aux(1)

    ! Sentinel outside the bracket: golden is used if Newton is not called
    haux = -2d0

    ! u concave at every grid point -> Newton (start at index >= 2: issues when starting at h=0)
    if (maxval(soc) < 0.0D0) then
        haux = Newton_uh_sc_TR(anp,h_vec(max(aux_int,2)),is_in,tt_in)
    endif

    ! Golden if Newton was skipped (-2), failed (-1 / NaN), or left the bracket
    if ( .NOT. ( (h_vec(max(aux_int-1,1)) <= haux) .AND. (haux <= h_vec(min(Nh,aux_int+1))) ) ) then
        haux = SolveBell_uh_sc_TR(anp,h_vec(max(aux_int-1,1)),h_vec(min(Nh,aux_int+1)),is_in,tt_in)
    endif

    ! Compute utility at the chosen hours
    yL   = yLaux*haux
    caux = (aaux + yK + yL - TAX_sc_TR(yL,yK,tt_in) - anp )/(1d0+tauc)
    caux = max(caux,1D-12)
    uu   = ((caux**(1.0D0-sg))/(1.0D0-sg)) - B*(((haux)**(1.0D0+varphi))/(1.0D0+varphi))

    V_out = uu + beta*Ve_np

    ! Save policies
    if (save_dum==1) then
        h_pol_TR(is_in,tt_in) = haux
        c_pol_TR(is_in,tt_in) = caux
    endif
    if (save_dum>1) then
        print *, 'Error algorithm'
    endif

    end function VF_eval_wa_hopt_sc_TR

    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    !!!!!  Golden Search [SCALAR] on a                     !!!!!
    !!!!!  h is computed optimally [NESTED GOLDEN]         !!!!!
    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

    function SolveBell_wa_hopt_sc_TR(xl1,xr1,is_in,tt_in) result(x_out)
    ! Golden-section search for the asset choice a in [xl1,xr1] that maximizes
    ! VF_eval_wa_hopt_sc_TR(a,is_in,0,tt_in), with hours re-optimized at each a.
    ! The bounds may be given in either order.
    !
    ! xl1, xr1 : search interval
    ! is_in    : state index
    ! tt_in    : period
    ! x_out    : better of the two interior points once the step falls below tol

    implicit none

    real(8), intent(in)  :: xl1, xr1        ! bounds for golden search
    integer, intent(in)  :: is_in, tt_in    ! state, time
    real(8) :: x_out                        ! output: asset choice
    real(8) :: xl0, xr0

    real(8) :: alpha1, alpha2
    real(8) :: d, x1, x2, f1, f2, x1new, x2new, f1new, f2new

    real(8), parameter :: tol = 1D-10

    alpha1 = 0.5D0*(3.0D0 - sqrt(5.0D0))
    alpha2 = 0.5D0*(sqrt(5.0D0) - 1.0D0)

    ! Accept bounds in either order (a one-sided overwrite would collapse the interval)
    xl0 = min(xl1,xr1)
    xr0 = max(xl1,xr1)

    d  = xr0 - xl0

    x1 = xl0 + alpha1*d
    x2 = xl0 + alpha2*d

    f1 = VF_eval_wa_hopt_sc_TR(x1,is_in,0,tt_in)
    f2 = VF_eval_wa_hopt_sc_TR(x2,is_in,0,tt_in)

    ! Step between consecutive interior points; shrinks by alpha2 each iteration
    d  = alpha1*alpha2*d

    f1new = f1
    f2new = f2
    x1new = x1
    x2new = x2

    do while (d > tol)

        f1 = f1new
        f2 = f2new
        x1 = x1new
        x2 = x2new

        d = d*alpha2

        if (f2 < f1) then
            ! maximum lies left of x2: shift the pair left
            x2new = x1
            x1new = x1 - d
            f2new = f1
            f1new = VF_eval_wa_hopt_sc_TR(x1-d,is_in,0,tt_in)
        else
            ! maximum lies right of x1: shift the pair right
            x2new = x2 + d
            x1new = x2
            f2new = VF_eval_wa_hopt_sc_TR(x2+d,is_in,0,tt_in)
            f1new = f2
        endif

    end do

    if (f2 < f1) then
        x_out = x1
    else
        x_out = x2
    endif

    end function SolveBell_wa_hopt_sc_TR

    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    !!!!!  Golden Search [SCALAR] on h, static opt on U     !!!!!
    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

    function SolveBell_uh_sc_TR(a_in,xl1,xr1,is_in,tt_in) result(x_out)
    ! Golden-section search for hours h in [xl1,xr1] maximizing static utility
    !   u(h) = c**(1-sg)/(1-sg) - B*h**(1+varphi)/(1+varphi),
    !   c    = max( (a + yk + yLaux*h - T(yLaux*h,yk) - a_in)/(1+tauc), 1D-12 ),
    ! for a given next-period asset choice a_in. The bounds may be given in
    ! either order.
    !
    ! a_in     : next-period asset choice
    ! xl1, xr1 : bracket for hours
    ! is_in    : state index (S(is,1) = assets, S(is,2) = productivity)
    ! tt_in    : period
    ! x_out    : better of the two interior points once the step falls below tol

    implicit none

    ! Local variable declarations
    real(8), intent(in)  :: a_in, xl1, xr1      ! given asset choice, bounds for h
    integer, intent(in)  :: is_in, tt_in        ! state, time
    real(8) :: x_out                            ! output: labor choice
    real(8) :: xl0, xr0

    real(8) :: alpha1, alpha2
    real(8) :: d, x1, x2, f1, f2, x1new, x2new, f1new, f2new

    real(8), parameter :: tol = 1D-10

    real(8) :: yK, yLaux, aaux                  ! state data used by util_h (host association)

    aaux  = S(is_in,1)                  ! assets
    yK    = r_TR(tt_in)*aaux            ! capital income
    yLaux = wge_TR(tt_in)*S(is_in,2)    ! labor income per unit of hours

    alpha1 = 0.5D0*(3.0D0 - sqrt(5.0D0))
    alpha2 = 0.5D0*(sqrt(5.0D0) - 1.0D0)

    ! Accept bounds in either order (a one-sided overwrite would collapse the interval)
    xl0 = min(xl1,xr1)
    xr0 = max(xl1,xr1)

    d  = xr0 - xl0

    x1 = xl0 + alpha1*d
    x2 = xl0 + alpha2*d

    f1 = util_h(x1)
    f2 = util_h(x2)

    ! Step between consecutive interior points; shrinks by alpha2 each iteration
    d  = alpha1*alpha2*d

    f1new = f1
    f2new = f2
    x1new = x1
    x2new = x2

    do while (d > tol)

        f1 = f1new
        f2 = f2new
        x1 = x1new
        x2 = x2new

        d = d*alpha2

        if (f2 < f1) then
            x2new = x1
            x1new = x1 - d
            f2new = f1
            f1new = util_h(x1-d)
        else
            x2new = x2 + d
            x1new = x2
            f2new = util_h(x2+d)
            f1new = f2
        endif

    end do

    if (f2 < f1) then
        x_out = x1
    else
        x_out = x2
    endif

    contains

        ! Static utility of working h hours, given a_in and the state's incomes
        ! (aaux, yK, yLaux, a_in, tt_in come from the host by host association)
        function util_h(h) result(u_h)
            real(8), intent(in) :: h
            real(8) :: u_h
            real(8) :: yl_h, c_h

            yl_h = yLaux*h
            c_h  = (aaux + (yl_h + yK) - TAX_sc_TR(yl_h,yK,tt_in) - a_in)/(1d0+tauc)
            c_h  = max(c_h,1D-12)
            u_h  = ( (c_h**(1.0D0-sg))/(1.0D0-sg) ) - B*( (h**(1.0D0+varphi))/(1.0D0+varphi) )
        end function util_h

    end function SolveBell_uh_sc_TR


    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    !!!!!  NEWTON [SCALAR] on h, static opt on U            !!!!!
    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

    function Newton_uh_sc_TR(a_in,h_in,is_in,tt_in) result(h_out)
    ! Newton iteration on the first-order condition u'(h) = 0 of static utility
    !   u(h) = c**(1-sg)/(1-sg) - B*h**(1+varphi)/(1+varphi),
    !   c(h) = (a + yk + yLaux*h - T(yLaux*h,yk) - a_in)/(1+tauc)   (floored at 1D-12),
    ! starting from h_in. Derivatives:
    !   u'(h)  = c**(-sg)*dc/dh - B*h**varphi
    !   u''(h) = -sg*c**(-sg-1)*(dc/dh)**2 + c**(-sg)*d2c/dh2 - varphi*B*h**(varphi-1)
    !   dc/dh  = yLaux*(1-T')/(1+tauc),  d2c/dh2 = -yLaux**2*T''/(1+tauc)
    ! Returns -1 when the iteration does not converge (last step above tol after
    ! maxiter steps, or a NaN step). The caller's bracket check then falls back
    ! to golden-section search.
    !
    ! a_in  : next-period asset choice
    ! h_in  : starting hours
    ! is_in : state index
    ! tt_in : period

    implicit none

    real(8), intent(in)  :: a_in, h_in      ! given asset input, starting point for h
    integer, intent(in)  :: is_in, tt_in    ! state, time
    real(8) :: h_out                        ! output: labor choice

    real(8), parameter :: tol = 1D-5        ! tolerance on the Newton step
    integer, parameter :: maxiter = 30      ! max number of iterations

    real(8) :: h_init, h_end, err
    integer :: iter

    real(8) :: caux, yK, yLaux, yL, aaux
    real(8) :: dcdh, d2cdh2, foc_out, soc_out

    aaux  = S(is_in,1)                  ! assets
    yK    = r_TR(tt_in)*aaux            ! capital income
    yLaux = wge_TR(tt_in)*S(is_in,2)    ! labor income per hour worked

    err    = 1.0D0                      ! initialize error
    iter   = 0                          ! initialize iteration counter
    h_init = h_in                       ! starting point for hours

    do while ((err > tol) .and. (iter < maxiter))

        yL   = yLaux*h_init
        caux = (aaux + yL + yK - TAX_sc_TR(yL,yK,tt_in) - a_in)/(1d0+tauc)
        caux = max(caux,1D-12)

        dcdh   = yLaux*(1d0-dTAX_sc_TR(yL,yK,tt_in))/(1d0+tauc)     ! dc/dh
        d2cdh2 = -(yLaux**2)*ddTAX_sc_TR(yL,yK,tt_in)/(1d0+tauc)    ! d2c/dh2

        ! FOC and SOC; (dc/dh)**2 carries the full 1/(1+tauc)**2 factor
        foc_out = (caux**(-sg))*dcdh - B*(h_init**varphi)
        soc_out = -sg*(caux**(-sg-1d0))*(dcdh**2) + (caux**(-sg))*d2cdh2 - varphi*B*(h_init**(varphi-1d0))

        ! Newton step
        h_end  = h_init - foc_out/soc_out
        err    = abs(h_init-h_end)
        iter   = iter + 1
        h_init = h_end

    enddo

    ! Success means the last step is within tol. Written as .not.(err <= tol)
    ! so that a NaN step (which also ends the loop) is treated as a failure.
    if (.not. (err <= tol)) then
        h_out = -1.0D0
    else
        h_out = h_init
    endif

    end function Newton_uh_sc_TR

end module Functions_wa_TR
